# This code is part of Qiskit.
#
# (C) Copyright IBM 2017, 2019.
#
# This code is licensed under the Apache License, Version 2.0. You may
# obtain a copy of this license in the LICENSE.txt file in the root directory
# of this source tree or at http://www.apache.org/licenses/LICENSE-2.0.
#
# Any modifications or derivative works of this code must retain this
# copyright notice, and modified files need to carry a notice indicating
# that they have been altered from the originals.

"""
Abstract base class for Quantum Channels.
"""

from __future__ import annotations
import copy
import sys
from abc import abstractmethod
from numbers import Number, Integral

import numpy as np

from qiskit.circuit.instruction import Instruction
from qiskit.exceptions import QiskitError
from qiskit.quantum_info.operators.linear_op import LinearOp
from qiskit.quantum_info.operators.op_shape import OpShape
from qiskit.quantum_info.operators.operator import Operator
from qiskit.quantum_info.operators.predicates import is_positive_semidefinite_matrix
from qiskit.quantum_info.operators.channel.transformations import _transform_rep
from qiskit.quantum_info.operators.channel.transformations import _to_choi
from qiskit.quantum_info.operators.channel.transformations import _to_kraus
from qiskit.quantum_info.operators.channel.transformations import _to_operator
from qiskit.quantum_info.operators.scalar_op import ScalarOp

if sys.version_info >= (3, 11):
    from typing import Self
else:
    from typing_extensions import Self


class QuantumChannel(LinearOp):
    """Quantum channel representation base class."""

    def __init__(
        self,
        data: list | np.ndarray,
        num_qubits: int | None = None,
        op_shape: OpShape | None = None,
    ):
        """Initialize a quantum channel Superoperator operator.

        Args:
            data (array or list): quantum channel data array.
            op_shape (OpShape): the operator shape of the channel.
            num_qubits (int): the number of qubits if the channel is N-qubit.

        Raises:
            QiskitError: if arguments are invalid.
        """
        self._data = data
        super().__init__(num_qubits=num_qubits, op_shape=op_shape)

    def __repr__(self):
        prefix = f"{self._channel_rep}("
        pad = len(prefix) * " "
        return "{}{},\n{}input_dims={}, output_dims={})".format(
            prefix,
            np.array2string(np.asarray(self.data), separator=", ", prefix=prefix),
            pad,
            self.input_dims(),
            self.output_dims(),
        )

    def __eq__(self, other: Self):
        """Test if two QuantumChannels are equal."""
        if not super().__eq__(other):
            return False
        return np.allclose(self.data, other.data, rtol=self.rtol, atol=self.atol)

    @property
    def data(self):
        """Return data."""
        return self._data

    @property
    def _channel_rep(self):
        """Return channel representation string"""
        return type(self).__name__

    @property
    def settings(self):
        """Return settings."""
        return {
            "data": self.data,
            "input_dims": self.input_dims(),
            "output_dims": self.output_dims(),
        }

    # ---------------------------------------------------------------------
    # LinearOp methods
    # ---------------------------------------------------------------------

    @abstractmethod
    def conjugate(self):
        r"""Return the conjugate quantum channel.

        .. note::
            This is equivalent to the matrix complex conjugate in the
            :class:`~qiskit.quantum_info.SuperOp` representation
            ie. for a channel :math:`\mathcal{E}`, the SuperOp of
            the conjugate channel :math:`\overline{{\mathcal{{E}}}}` is
            :math:`S_{\overline{\mathcal{E}^\dagger}} = \overline{S_{\mathcal{E}}}`.
        """

    @abstractmethod
    def transpose(self) -> Self:
        r"""Return the transpose quantum channel.

        .. note::
            This is equivalent to the matrix transpose in the
            :class:`~qiskit.quantum_info.SuperOp` representation,
            ie. for a channel :math:`\mathcal{E}`, the SuperOp of
            the transpose channel :math:`\mathcal{{E}}^T` is
            :math:`S_{mathcal{E}^T} = S_{\mathcal{E}}^T`.
        """

    def adjoint(self) -> Self:
        r"""Return the adjoint quantum channel.

        .. note::
            This is equivalent to the matrix Hermitian conjugate in the
            :class:`~qiskit.quantum_info.SuperOp` representation
            ie. for a channel :math:`\mathcal{E}`, the SuperOp of
            the adjoint channel :math:`\mathcal{{E}}^\dagger` is
            :math:`S_{\mathcal{E}^\dagger} = S_{\mathcal{E}}^\dagger`.
        """
        return self.conjugate().transpose()

    def power(self, n: float) -> Self:
        r"""Return the power of the quantum channel.

        Args:
            n (float): the power exponent.

        Returns:
            CLASS: the channel :math:`\mathcal{{E}} ^n`.

        Raises:
            QiskitError: if the input and output dimensions of the
                         CLASS are not equal.

        .. note::
            For non-positive or non-integer exponents the power is
            defined as the matrix power of the
            :class:`~qiskit.quantum_info.SuperOp` representation
            ie. for a channel :math:`\mathcal{{E}}`, the SuperOp of
            the powered channel :math:`\mathcal{{E}}^\n` is
            :math:`S_{{\mathcal{{E}}^n}} = S_{{\mathcal{{E}}}}^n`.
        """
        if n > 0 and isinstance(n, Integral):
            return super().power(n)

        # Conversion to superoperator
        if self._input_dim != self._output_dim:
            raise QiskitError("Can only take power with input_dim = output_dim.")
        rep = self._channel_rep
        input_dim, output_dim = self.dim
        superop = _transform_rep(rep, "SuperOp", self._data, input_dim, output_dim)
        superop = np.linalg.matrix_power(superop, n)

        # Convert back to original representation
        ret = copy.copy(self)
        ret._data = _transform_rep("SuperOp", rep, superop, input_dim, output_dim)
        return ret

    def __sub__(self, other) -> Self:
        qargs = getattr(other, "qargs", None)
        if not isinstance(other, type(self)):
            other = type(self)(other)
        return self._add(-other, qargs=qargs)

    def _add(self, other, qargs=None):
        # NOTE: this method must be overridden for subclasses
        # that don't have a linear matrix representation
        # ie Kraus and Stinespring
        if not isinstance(other, type(self)):
            other = type(self)(other)
        self._op_shape._validate_add(other._op_shape, qargs)
        other = ScalarOp._pad_with_identity(self, other, qargs)
        ret = copy.copy(self)
        ret._data = self._data + other._data
        return ret

    def _multiply(self, other):
        # NOTE: this method must be overridden for subclasses
        # that don't have a linear matrix representation
        # ie Kraus and Stinespring
        if not isinstance(other, Number):
            raise QiskitError("other is not a number")
        ret = copy.copy(self)
        ret._data = other * self._data
        return ret

    # ---------------------------------------------------------------------
    # Additional methods
    # ---------------------------------------------------------------------

    def is_cptp(self, atol: float | None = None, rtol: float | None = None) -> bool:
        """Return True if completely-positive trace-preserving (CPTP)."""
        choi = _to_choi(self._channel_rep, self._data, *self.dim)
        return self._is_cp_helper(choi, atol, rtol) and self._is_tp_helper(choi, atol, rtol)

    def is_tp(self, atol: float | None = None, rtol: float | None = None) -> bool:
        """Test if a channel is trace-preserving (TP)"""
        choi = _to_choi(self._channel_rep, self._data, *self.dim)
        return self._is_tp_helper(choi, atol, rtol)

    def is_cp(self, atol: float | None = None, rtol: float | None = None) -> bool:
        """Test if Choi-matrix is completely-positive (CP)"""
        choi = _to_choi(self._channel_rep, self._data, *self.dim)
        return self._is_cp_helper(choi, atol, rtol)

    def is_unitary(self, atol: float | None = None, rtol: float | None = None) -> bool:
        """Return True if QuantumChannel is a unitary channel."""
        try:
            op = self.to_operator()
            return op.is_unitary(atol=atol, rtol=rtol)
        except QiskitError:
            return False

    def to_operator(self) -> Operator:
        """Try to convert channel to a unitary representation Operator."""
        mat = _to_operator(self._channel_rep, self._data, *self.dim)
        return Operator(mat, self.input_dims(), self.output_dims())

    def to_instruction(self) -> Instruction:
        """Convert to a Kraus or UnitaryGate circuit instruction.

        If the channel is unitary it will be added as a unitary gate,
        otherwise it will be added as a kraus simulator instruction.

        Returns:
            qiskit.circuit.Instruction: A kraus instruction for the channel.

        Raises:
            QiskitError: if input data is not an N-qubit CPTP quantum channel.
        """

        # Check if input is an N-qubit CPTP channel.
        num_qubits = int(np.log2(self._input_dim))
        if self._input_dim != self._output_dim or 2**num_qubits != self._input_dim:
            raise QiskitError(
                "Cannot convert QuantumChannel to Instruction: channel is not an N-qubit channel."
            )
        if not self.is_cptp():
            raise QiskitError("Cannot convert QuantumChannel to Instruction: channel is not CPTP.")
        # Next we convert to the Kraus representation. Since channel is CPTP we know
        # that there is only a single set of Kraus operators
        kraus, _ = _to_kraus(self._channel_rep, self._data, *self.dim)
        # If we only have a single Kraus operator then the channel is
        # a unitary channel so can be converted to a UnitaryGate. We do this by
        # converting to an Operator and using its to_instruction method
        if len(kraus) == 1:
            return Operator(kraus[0]).to_instruction()
        return Instruction("kraus", num_qubits, 0, kraus)

    def _is_cp_helper(self, choi, atol, rtol):
        """Test if a channel is completely-positive (CP)"""
        if atol is None:
            atol = self.atol
        if rtol is None:
            rtol = self.rtol
        return is_positive_semidefinite_matrix(choi, rtol=rtol, atol=atol)

    def _is_tp_helper(self, choi, atol, rtol):
        """Test if Choi-matrix is trace-preserving (TP)"""
        if atol is None:
            atol = self.atol
        if rtol is None:
            rtol = self.rtol
        # Check if the partial trace is the identity matrix
        d_in, d_out = self.dim
        mat = np.trace(np.reshape(choi, (d_in, d_out, d_in, d_out)), axis1=1, axis2=3)
        tp_cond = np.linalg.eigvalsh(mat - np.eye(len(mat)))
        zero = np.isclose(tp_cond, 0, atol=atol, rtol=rtol)
        return np.all(zero)

    def _format_state(self, state, density_matrix=False):
        """Format input state so it is statevector or density matrix"""
        state = np.array(state)
        shape = state.shape
        ndim = state.ndim
        if ndim > 2:
            raise QiskitError("Input state is not a vector or matrix.")
        # Flatten column-vector to vector
        if ndim == 2:
            if shape[1] != 1 and shape[1] != shape[0]:
                raise QiskitError("Input state is not a vector or matrix.")
            if shape[1] == 1:
                # flatten column-vector to vector
                state = np.reshape(state, shape[0])
        # Convert statevector to density matrix if required
        if density_matrix and ndim == 1:
            state = np.outer(state, np.transpose(np.conj(state)))
        return state

    @abstractmethod
    def _evolve(self, state, qargs=None):
        """Evolve a quantum state by the quantum channel.

        Args:
            state (DensityMatrix or Statevector): The input state.
            qargs (list): a list of quantum state subsystem positions to apply
                           the quantum channel on.

        Returns:
            DensityMatrix: the output quantum state as a density matrix.

        Raises:
            QiskitError: if the quantum channel dimension does not match the
                         specified quantum state subsystem dimensions.
        """
        pass

    @classmethod
    def _init_transformer(cls, data):
        """Convert input into a QuantumChannel subclass object or Operator object"""
        # This handles common conversion for all QuantumChannel subclasses.
        # If the input is already a QuantumChannel subclass it will return
        # the original object
        if isinstance(data, QuantumChannel):
            return data
        if hasattr(data, "to_quantumchannel"):
            # If the data object is not a QuantumChannel it will give
            # preference to a 'to_quantumchannel' attribute that allows
            # an arbitrary object to define its own conversion to any
            # quantum channel subclass.
            return data.to_quantumchannel()
        if hasattr(data, "to_channel"):
            # TODO: this 'to_channel' method is the same case as the above
            # but is used by current version of Aer. It should be removed
            # once Aer is nupdated to use `to_quantumchannel`
            # instead of `to_channel`,
            return data.to_channel()
        # Finally if the input is not a QuantumChannel and doesn't have a
        # 'to_quantumchannel' conversion method we try and initialize it as a
        # regular matrix Operator which can be converted into a QuantumChannel.
        return Operator(data)
